balanced_accuracy_score#
Balanced accuracy is the macro-average of recall: it computes recall separately for each class and then averages across classes.
Quick import#
from sklearn.metrics import balanced_accuracy_score
It is especially useful when:
the dataset is imbalanced
you want to treat classes equally (e.g., you care about minority recall as much as majority recall)
Goals
Build intuition (why accuracy can be misleading)
Derive the metric for binary and multiclass classification
Implement balanced_accuracy_score from scratch in NumPy
Visualize per-class recall and threshold effects (Plotly)
Use balanced accuracy to guide a simple optimization loop (from-scratch logistic regression)
Prerequisites
Confusion matrix, recall (TPR), specificity (TNR)
Probabilistic classifiers (logistic regression outputs probabilities)
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import os
import plotly.io as pio
from plotly.subplots import make_subplots
from scipy.special import expit
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import balanced_accuracy_score as sk_balanced_accuracy_score
from sklearn.model_selection import train_test_split
pio.renderers.default = os.environ.get("PLOTLY_RENDERER", "notebook")
pio.templates.default = "plotly_white"
rng = np.random.default_rng(42)
np.set_printoptions(precision=4, suppress=True)
1) Why not plain accuracy?#
Accuracy is the fraction of correct predictions:

\[
\text{Accuracy} = \frac{TP + TN}{TP + TN + FP + FN}
\]
If one class dominates, a model can achieve high accuracy by mostly predicting the majority class.
Example: 99% negatives, 1% positives.
A classifier that predicts always negative gets 99% accuracy.
But it has 0% recall on the positive class.
Balanced accuracy fixes this by computing recall per class and averaging them, so each class contributes equally.
2) Definition (binary classification)#
For \(y \in \{0,1\}\) and predictions \(\hat{y} \in \{0,1\}\):
|  | \(\hat{y}=0\) | \(\hat{y}=1\) |
|---|---|---|
| \(y=0\) | TN | FP |
| \(y=1\) | FN | TP |
Two key rates:

Recall / sensitivity / TPR: \(\text{TPR} = \dfrac{TP}{TP + FN}\)

Specificity / TNR: \(\text{TNR} = \dfrac{TN}{TN + FP}\)

Balanced accuracy is the mean of these two:

\[
\text{BA} = \frac{\text{TPR} + \text{TNR}}{2}
\]

It is also related to the balanced error rate (BER): \(\text{BER} = 1 - \text{BA}\).
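As a quick numeric illustration (hypothetical counts): with TP = 30, FN = 10, TN = 50, FP = 50,

\[
\text{TPR} = \frac{30}{40} = 0.75,\qquad
\text{TNR} = \frac{50}{100} = 0.50,\qquad
\text{BA} = \frac{0.75 + 0.50}{2} = 0.625,
\]

while plain accuracy is \((30 + 50)/140 \approx 0.571\).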
3) Definition (multiclass + adjusted)#
For \(K\) classes, balanced accuracy is the average recall per class (macro recall):

\[
\text{BA} = \frac{1}{K} \sum_{k=1}^{K} \text{recall}_k,
\qquad
\text{recall}_k = \frac{\#\{\text{true class } k \text{ predicted as } k\}}{\#\{\text{true class } k\}}
\]

So balanced accuracy is exactly macro-averaged recall; in the binary case it reduces to the mean of TPR and TNR above.
Adjusted balanced accuracy#
scikit-learn also offers a chance-corrected version (adjusted=True):

\[
\text{BA}_{\text{adj}} = \frac{\text{BA} - \tfrac{1}{K}}{1 - \tfrac{1}{K}}
\]

A classifier that effectively behaves like random guessing tends toward \(\text{BA}_{\text{adj}} \approx 0\).
Perfect classification gives \(\text{BA}_{\text{adj}} = 1\).
Worse-than-chance performance gives a negative value.
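As a small sanity check (a sketch on made-up labels), the unadjusted score coincides with macro-averaged recall in scikit-learn:

import numpy as np
from sklearn.metrics import balanced_accuracy_score, recall_score

# Hypothetical multiclass labels, only to illustrate the equivalence
y_true = np.array([0, 0, 0, 1, 1, 2, 2, 2, 2])
y_pred = np.array([0, 0, 1, 1, 2, 2, 2, 0, 2])

print(recall_score(y_true, y_pred, average="macro"))            # macro recall
print(balanced_accuracy_score(y_true, y_pred))                  # same value
print(balanced_accuracy_score(y_true, y_pred, adjusted=True))   # (BA - 1/3) / (1 - 1/3) for K=3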
def accuracy_score_np(y_true, y_pred, sample_weight=None) -> float:
y_true = np.asarray(y_true)
y_pred = np.asarray(y_pred)
correct = (y_true == y_pred).astype(float)
if sample_weight is None:
return float(correct.mean())
w = np.asarray(sample_weight, dtype=float)
return float(np.sum(w * correct) / np.sum(w))
def per_class_recall_np(
y_true,
y_pred,
labels=None,
sample_weight=None,
zero_division: float = 0.0,
):
# Per-class recall:
# recall_k = (# predicted as k among true k) / (# true k)
y_true = np.asarray(y_true)
y_pred = np.asarray(y_pred)
if labels is None:
labels = np.unique(y_true)
labels = np.asarray(labels)
if sample_weight is None:
sample_weight = np.ones_like(y_true, dtype=float)
else:
sample_weight = np.asarray(sample_weight, dtype=float)
recalls = np.empty(len(labels), dtype=float)
for i, cls in enumerate(labels):
mask = y_true == cls
denom = float(sample_weight[mask].sum())
if denom == 0.0:
recalls[i] = zero_division
else:
num = float(sample_weight[mask & (y_pred == cls)].sum())
recalls[i] = num / denom
return recalls, labels
def balanced_accuracy_score_np(
y_true,
y_pred,
*,
labels=None,
sample_weight=None,
adjusted: bool = False,
zero_division: float = 0.0,
) -> float:
recalls, labels_used = per_class_recall_np(
y_true,
y_pred,
labels=labels,
sample_weight=sample_weight,
zero_division=zero_division,
)
score = float(np.mean(recalls))
if not adjusted:
return score
n_classes = len(labels_used)
if n_classes <= 1:
return 1.0
chance = 1.0 / n_classes
return float((score - chance) / (1.0 - chance))
def confusion_matrix_np(y_true, y_pred, labels=None, sample_weight=None):
# Small confusion-matrix helper (mainly for plotting)
y_true = np.asarray(y_true)
y_pred = np.asarray(y_pred)
if labels is None:
labels = np.unique(np.concatenate([y_true, y_pred]))
labels = np.asarray(labels)
label_to_index = {label: i for i, label in enumerate(labels)}
true_idx = np.array([label_to_index.get(v, -1) for v in y_true], dtype=int)
pred_idx = np.array([label_to_index.get(v, -1) for v in y_pred], dtype=int)
if sample_weight is None:
sample_weight = np.ones_like(true_idx, dtype=float)
else:
sample_weight = np.asarray(sample_weight, dtype=float)
cm = np.zeros((len(labels), len(labels)), dtype=float)
valid = (true_idx >= 0) & (pred_idx >= 0)
np.add.at(cm, (true_idx[valid], pred_idx[valid]), sample_weight[valid])
return cm, labels
# quick sanity check vs scikit-learn
_y_true = np.array([0, 0, 0, 1, 1, 1])
_y_pred = np.array([0, 0, 1, 0, 1, 1])
print('ours:', balanced_accuracy_score_np(_y_true, _y_pred))
print('sklearn:', sk_balanced_accuracy_score(_y_true, _y_pred))
ours: 0.6666666666666666
sklearn: 0.6666666666666666
4) Worked example: “always predict the majority class”#
Let’s build an extremely imbalanced dataset and evaluate a trivial classifier.
990 negatives (class 0)
10 positives (class 1)
predictions: always class 0
This classifier gets excellent accuracy, but poor minority-class performance.
n_neg, n_pos = 990, 10
y_true = np.array([0] * n_neg + [1] * n_pos)
y_pred = np.zeros_like(y_true)
acc = accuracy_score_np(y_true, y_pred)
bal = balanced_accuracy_score_np(y_true, y_pred)
bal_adj = balanced_accuracy_score_np(y_true, y_pred, adjusted=True)
recalls, labels = per_class_recall_np(y_true, y_pred)
print(f"accuracy: {acc:.4f}")
print(f"balanced accuracy: {bal:.4f}")
print(f"adjusted BA: {bal_adj:.4f}")
print("per-class recall:", dict(zip(labels.tolist(), recalls.tolist())))
accuracy: 0.9900
balanced accuracy: 0.5000
adjusted BA: 0.0000
per-class recall: {0: 1.0, 1: 0.0}
cm, cm_labels = confusion_matrix_np(y_true, y_pred)
fig = px.imshow(
cm,
text_auto=True,
color_continuous_scale="Blues",
x=[f"pred={l}" for l in cm_labels],
y=[f"true={l}" for l in cm_labels],
)
fig.update_layout(title="Confusion matrix: always predicting class 0")
fig.show()
fig = go.Figure(
data=[
go.Bar(
x=[str(l) for l in labels],
y=recalls,
text=[f"{r:.2f}" for r in recalls],
textposition="auto",
)
]
)
fig.update_layout(
title="Per-class recall (balanced accuracy is the mean of these)",
xaxis_title="class",
yaxis_title="recall",
yaxis=dict(range=[0, 1]),
)
fig.show()
5) Threshold dependence (probabilities → labels)#
Balanced accuracy is defined on hard predictions (\(\hat{y}\)).
If your model outputs probabilities \(p(x) = P(y=1\mid x)\), you still need a decision threshold \(t\):

\[
\hat{y}(x) = \mathbb{1}\left[\, p(x) \ge t \,\right]
\]
Changing \(t\) changes the confusion matrix, hence recall per class, hence balanced accuracy.
# A simple probability simulation (overlapping scores + class imbalance)
n_neg, n_pos = 2000, 100
y_true = np.array([0] * n_neg + [1] * n_pos)
# Negatives tend to have lower predicted probabilities, positives higher, but overlapping.
p_neg = rng.beta(2.0, 8.0, size=n_neg)
p_pos = rng.beta(5.0, 5.0, size=n_pos)
proba = np.concatenate([p_neg, p_pos])
# Shuffle together
perm = rng.permutation(len(y_true))
y_true = y_true[perm]
proba = proba[perm]
thresholds = np.linspace(0.0, 1.0, 401)
accs = np.empty_like(thresholds)
bals = np.empty_like(thresholds)
for i, t in enumerate(thresholds):
y_pred = (proba >= t).astype(int)
accs[i] = accuracy_score_np(y_true, y_pred)
bals[i] = balanced_accuracy_score_np(y_true, y_pred)
best_t = float(thresholds[np.argmax(bals)])
fig = go.Figure()
fig.add_trace(go.Scatter(x=thresholds, y=accs, name="accuracy", mode="lines"))
fig.add_trace(go.Scatter(x=thresholds, y=bals, name="balanced accuracy", mode="lines"))
fig.add_vline(x=best_t, line_dash="dash", line_color="black")
fig.update_layout(
title=f"Accuracy vs balanced accuracy as a function of threshold (best BA at t={best_t:.3f})",
xaxis_title="threshold t",
yaxis_title="score",
yaxis=dict(range=[0, 1]),
)
fig.show()
6) Using balanced accuracy to guide an optimization loop (logistic regression)#
Balanced accuracy is not differentiable w.r.t. model parameters because it depends on discrete decisions (argmax / threshold).
In practice, we typically:
Train a probabilistic classifier with a differentiable loss (e.g., log loss)
Use balanced accuracy as a model selection criterion:
choose hyperparameters
choose early-stopping epoch
choose decision threshold
A common surrogate that often improves balanced accuracy is to train with class weights (roughly: make each class contribute equally to the loss).
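For reference, the "balanced" heuristic (the same one scikit-learn uses for class_weight="balanced") weights each class inversely to its frequency:

\[
w_k = \frac{n}{K \cdot n_k}
\]

where \(n\) is the number of training samples, \(K\) the number of classes, and \(n_k\) the count of class \(k\); with this choice every class contributes the same total weight to the loss.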
Below we train logistic regression from scratch in two ways:
Unweighted log loss
Class-weighted log loss (“balanced” weights)
…and monitor validation balanced accuracy for early stopping.
# Synthetic 2D imbalanced dataset (mild overlap)
n0, n1 = 1200, 80
X0 = rng.normal(loc=(0.0, 0.0), scale=1.0, size=(n0, 2))
X1 = rng.normal(loc=(1.2, 1.2), scale=1.0, size=(n1, 2))
X = np.vstack([X0, X1])
y = np.concatenate([np.zeros(n0, dtype=int), np.ones(n1, dtype=int)])
perm = rng.permutation(len(y))
X, y = X[perm], y[perm]
X_train, X_val, y_train, y_val = train_test_split(
X, y, test_size=0.25, random_state=0, stratify=y
)
fig = px.scatter(
x=X[:, 0],
y=X[:, 1],
color=y.astype(str),
opacity=0.7,
title="Synthetic imbalanced dataset",
labels={"x": "x1", "y": "x2", "color": "class"},
)
fig.show()
print('train class counts:', {0: int((y_train==0).sum()), 1: int((y_train==1).sum())})
print('val class counts: ', {0: int((y_val==0).sum()), 1: int((y_val==1).sum())})
train class counts: {0: 900, 1: 60}
val class counts: {0: 300, 1: 20}
def standardize_fit(X):
mean = X.mean(axis=0)
std = X.std(axis=0) + 1e-12
return mean, std
def standardize_transform(X, mean, std):
return (X - mean) / std
def add_intercept(X):
return np.c_[np.ones((X.shape[0], 1)), X]
def predict_proba_logreg(X, w):
Xb = add_intercept(X)
return expit(Xb @ w)
def log_loss_binary(y, p, sample_weight=None, eps: float = 1e-12) -> float:
y = np.asarray(y)
p = np.clip(np.asarray(p), eps, 1.0 - eps)
per_sample = -(y * np.log(p) + (1.0 - y) * np.log(1.0 - p))
if sample_weight is None:
return float(per_sample.mean())
w = np.asarray(sample_weight, dtype=float)
return float(np.sum(w * per_sample) / np.sum(w))
def fit_logreg_gd(
X_train,
y_train,
X_val,
y_val,
*,
lr: float = 0.2,
n_epochs: int = 400,
l2: float = 1e-2,
sample_weight=None,
):
# Binary logistic regression with (optional) sample weights + early stopping on val BA
Xb = add_intercept(X_train)
n, d = Xb.shape
if sample_weight is None:
sample_weight = np.ones(n, dtype=float)
else:
sample_weight = np.asarray(sample_weight, dtype=float)
sw_sum = float(sample_weight.sum())
w = np.zeros(d, dtype=float)
history = {
"train_loss": [],
"val_acc": [],
"val_bal_acc": [],
}
best = {
"epoch": -1,
"val_bal_acc": -np.inf,
"w": w.copy(),
}
for epoch in range(n_epochs):
# forward + gradient on train
p_train = expit(Xb @ w)
grad = (Xb.T @ (sample_weight * (p_train - y_train))) / sw_sum
grad[1:] += l2 * w[1:]
w = w - lr * grad
# metrics
p_train = expit(Xb @ w)
train_loss = log_loss_binary(y_train, p_train, sample_weight=sample_weight) + 0.5 * l2 * float(
np.sum(w[1:] ** 2)
)
p_val = predict_proba_logreg(X_val, w)
y_val_hat = (p_val >= 0.5).astype(int)
val_acc = accuracy_score_np(y_val, y_val_hat)
val_bal_acc = balanced_accuracy_score_np(y_val, y_val_hat)
history["train_loss"].append(train_loss)
history["val_acc"].append(val_acc)
history["val_bal_acc"].append(val_bal_acc)
if val_bal_acc > best["val_bal_acc"]:
best = {"epoch": epoch, "val_bal_acc": val_bal_acc, "w": w.copy()}
return best["w"], history, best
# Standardize features (important for GD stability)
mean, std = standardize_fit(X_train)
X_train_s = standardize_transform(X_train, mean, std)
X_val_s = standardize_transform(X_val, mean, std)
# Unweighted training
w_unw, hist_unw, best_unw = fit_logreg_gd(X_train_s, y_train, X_val_s, y_val)
# Balanced class weights: each class gets ~50% of total weight
n_train = len(y_train)
n_pos = int((y_train == 1).sum())
n_neg = int((y_train == 0).sum())
w_pos = n_train / (2.0 * n_pos)
w_neg = n_train / (2.0 * n_neg)
sw_bal = np.where(y_train == 1, w_pos, w_neg)
w_wt, hist_wt, best_wt = fit_logreg_gd(X_train_s, y_train, X_val_s, y_val, sample_weight=sw_bal)
print('best epoch (unweighted):', best_unw['epoch'], 'val BA:', f"{best_unw['val_bal_acc']:.4f}")
print('best epoch (weighted): ', best_wt['epoch'], 'val BA:', f"{best_wt['val_bal_acc']:.4f}")
best epoch (unweighted): 179 val BA: 0.5250
best epoch (weighted): 53 val BA: 0.8367
epochs = np.arange(1, len(hist_unw["train_loss"]) + 1)
fig = make_subplots(
rows=1,
cols=3,
subplot_titles=("Train log loss", "Validation accuracy", "Validation balanced accuracy"),
)
for name, hist in [("unweighted", hist_unw), ("class-weighted", hist_wt)]:
fig.add_trace(
go.Scatter(x=epochs, y=hist["train_loss"], name=f"{name} loss", mode="lines"),
row=1,
col=1,
)
fig.add_trace(
go.Scatter(x=epochs, y=hist["val_acc"], name=f"{name} acc", mode="lines"),
row=1,
col=2,
)
fig.add_trace(
go.Scatter(x=epochs, y=hist["val_bal_acc"], name=f"{name} BA", mode="lines"),
row=1,
col=3,
)
fig.update_layout(height=350, width=1100, title="Training curves (early stopping uses validation BA)")
fig.update_yaxes(range=[0, 1], row=1, col=2)
fig.update_yaxes(range=[0, 1], row=1, col=3)
fig.show()
def best_threshold_for_balanced_accuracy(y_true, proba, thresholds):
best = {"t": None, "ba": -np.inf}
for t in thresholds:
y_pred = (proba >= t).astype(int)
ba = balanced_accuracy_score_np(y_true, y_pred)
if ba > best["ba"]:
best = {"t": float(t), "ba": float(ba)}
return best
thresholds = np.linspace(0.0, 1.0, 401)
p_unw = predict_proba_logreg(X_val_s, w_unw)
p_wt = predict_proba_logreg(X_val_s, w_wt)
best_t_unw = best_threshold_for_balanced_accuracy(y_val, p_unw, thresholds)
best_t_wt = best_threshold_for_balanced_accuracy(y_val, p_wt, thresholds)
print('best threshold (unweighted):', best_t_unw)
print('best threshold (weighted): ', best_t_wt)
# Visualize BA(t)
ba_unw = [balanced_accuracy_score_np(y_val, (p_unw >= t).astype(int)) for t in thresholds]
ba_wt = [balanced_accuracy_score_np(y_val, (p_wt >= t).astype(int)) for t in thresholds]
fig = go.Figure()
fig.add_trace(go.Scatter(x=thresholds, y=ba_unw, name="unweighted", mode="lines"))
fig.add_trace(go.Scatter(x=thresholds, y=ba_wt, name="class-weighted", mode="lines"))
fig.add_vline(x=best_t_unw["t"], line_dash="dash", line_color="#1f77b4")
fig.add_vline(x=best_t_wt["t"], line_dash="dash", line_color="#ff7f0e")
fig.update_layout(
title="Validation balanced accuracy as a function of the decision threshold",
xaxis_title="threshold t",
yaxis_title="balanced accuracy",
yaxis=dict(range=[0, 1]),
)
fig.show()
best threshold (unweighted): {'t': 0.1525, 'ba': 0.8433333333333334}
best threshold (weighted): {'t': 0.62, 'ba': 0.8416666666666667}
def summarize_threshold(y_true, proba, t):
y_pred = (proba >= t).astype(int)
acc = accuracy_score_np(y_true, y_pred)
ba = balanced_accuracy_score_np(y_true, y_pred)
recalls, labels = per_class_recall_np(y_true, y_pred)
cm, _ = confusion_matrix_np(y_true, y_pred, labels=np.array([0, 1]))
return {
"t": float(t),
"acc": float(acc),
"ba": float(ba),
"recalls": dict(zip(labels.tolist(), recalls.tolist())),
"cm": cm,
}
summaries = {
"unweighted @0.5": summarize_threshold(y_val, p_unw, 0.5),
"unweighted @t*": summarize_threshold(y_val, p_unw, best_t_unw["t"]),
"weighted @0.5": summarize_threshold(y_val, p_wt, 0.5),
"weighted @t*": summarize_threshold(y_val, p_wt, best_t_wt["t"]),
}
for k, v in summaries.items():
print(k, {"t": v["t"], "acc": v["acc"], "ba": v["ba"], "recalls": v["recalls"]})
# Confusion matrices (2x2): rows=methods, cols=threshold choice
fig = make_subplots(
rows=2,
cols=2,
subplot_titles=(
"Unweighted @0.5",
"Unweighted @t*",
"Weighted @0.5",
"Weighted @t*",
),
)
items = [
(1, 1, summaries["unweighted @0.5"]),
(1, 2, summaries["unweighted @t*"]),
(2, 1, summaries["weighted @0.5"]),
(2, 2, summaries["weighted @t*"]),
]
for r, c, s in items:
cm = s["cm"]
fig.add_trace(
go.Heatmap(
z=cm,
x=["pred=0", "pred=1"],
y=["true=0", "true=1"],
colorscale="Blues",
showscale=False,
text=cm.astype(int),
texttemplate="%{text}",
),
row=r,
col=c,
)
fig.update_layout(height=650, width=900, title="Validation confusion matrices")
fig.show()
unweighted @0.5 {'t': 0.5, 'acc': 0.940625, 'ba': 0.525, 'recalls': {0: 1.0, 1: 0.05}}
unweighted @t* {'t': 0.1525, 'acc': 0.88125, 'ba': 0.8433333333333334, 'recalls': {0: 0.8866666666666667, 1: 0.8}}
weighted @0.5 {'t': 0.5, 'acc': 0.7375, 'ba': 0.8366666666666667, 'recalls': {0: 0.7233333333333334, 1: 0.95}}
weighted @t* {'t': 0.62, 'acc': 0.834375, 'ba': 0.8416666666666667, 'recalls': {0: 0.8333333333333334, 1: 0.85}}
# Decision boundary visualization (in original feature space)
def decision_boundary_figure(X_val, y_val, w, mean, std, threshold: float, title: str):
x1_min, x1_max = X_val[:, 0].min() - 1.0, X_val[:, 0].max() + 1.0
x2_min, x2_max = X_val[:, 1].min() - 1.0, X_val[:, 1].max() + 1.0
xs = np.linspace(x1_min, x1_max, 200)
ys = np.linspace(x2_min, x2_max, 200)
xx, yy = np.meshgrid(xs, ys)
grid = np.c_[xx.ravel(), yy.ravel()]
grid_s = standardize_transform(grid, mean, std)
p = predict_proba_logreg(grid_s, w).reshape(xx.shape)
fig = go.Figure()
fig.add_trace(
go.Contour(
x=xs,
y=ys,
z=p,
contours=dict(start=threshold, end=threshold, size=1, coloring="lines"),
line=dict(color="black", width=3),
showscale=False,
name="decision boundary",
)
)
fig.add_trace(
go.Scatter(
x=X_val[:, 0],
y=X_val[:, 1],
mode="markers",
marker=dict(
size=6,
color=y_val,
colorscale=[[0, "#1f77b4"], [1, "#d62728"]],
opacity=0.7,
line=dict(width=0),
),
name="validation points",
)
)
fig.update_layout(
title=title,
xaxis_title="x1",
yaxis_title="x2",
height=450,
width=500,
legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
)
return fig
fig1 = decision_boundary_figure(
X_val,
y_val,
w_unw,
mean,
std,
threshold=best_t_unw["t"],
title=f"Unweighted logistic regression (threshold t*={best_t_unw['t']:.2f})",
)
fig2 = decision_boundary_figure(
X_val,
y_val,
w_wt,
mean,
std,
threshold=best_t_wt["t"],
title=f"Class-weighted logistic regression (threshold t*={best_t_wt['t']:.2f})",
)
fig = make_subplots(rows=1, cols=2, subplot_titles=(fig1.layout.title.text, fig2.layout.title.text))
for tr in fig1.data:
fig.add_trace(tr, row=1, col=1)
for tr in fig2.data:
fig.add_trace(tr, row=1, col=2)
fig.update_layout(height=450, width=1050, title="Decision boundary tuned for balanced accuracy")
fig.update_xaxes(title_text="x1", row=1, col=1)
fig.update_yaxes(title_text="x2", row=1, col=1)
fig.update_xaxes(title_text="x1", row=1, col=2)
fig.update_yaxes(title_text="x2", row=1, col=2)
fig.show()
7) Practical scikit-learn usage#
Metric#
from sklearn.metrics import balanced_accuracy_score
balanced_accuracy_score(y_true, y_pred)
balanced_accuracy_score(y_true, y_pred, adjusted=True)
Model selection#
In GridSearchCV / cross_val_score, use scoring="balanced_accuracy".
Many estimators support class_weight="balanced", which often improves balanced accuracy.
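A minimal sketch of model selection with this scorer (assuming the X_train, y_train split defined above; the grid values are arbitrary):

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV

# Rank hyperparameter candidates by balanced accuracy instead of plain accuracy
param_grid = {"C": [0.01, 0.1, 1.0, 10.0], "class_weight": [None, "balanced"]}
search = GridSearchCV(
    LogisticRegression(max_iter=2000),
    param_grid,
    scoring="balanced_accuracy",
    cv=5,
)
search.fit(X_train, y_train)
print(search.best_params_, search.best_score_)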
# scikit-learn comparison on the same dataset
clf_unw = LogisticRegression(max_iter=2000)
clf_wt = LogisticRegression(max_iter=2000, class_weight="balanced")
clf_unw.fit(X_train, y_train)
clf_wt.fit(X_train, y_train)
pred_unw = clf_unw.predict(X_val)
pred_wt = clf_wt.predict(X_val)
print('sklearn unweighted BA:', sk_balanced_accuracy_score(y_val, pred_unw))
print('sklearn weighted BA: ', sk_balanced_accuracy_score(y_val, pred_wt))
sklearn unweighted BA: 0.5483333333333333
sklearn weighted BA: 0.8183333333333334
8) Pros, cons, and when to use it#
Pros#
Handles class imbalance better than accuracy (each class contributes equally).
Easy to interpret: it is the average recall per class.
Works naturally for multiclass problems.
Cons / limitations#
Ignores precision: you can increase recall (and BA) by predicting a class more often, possibly creating many false positives.
Threshold-dependent: with probabilistic outputs, you may need to tune the decision threshold.
Not differentiable → typically used for evaluation/model selection, not as a direct training loss.
Equal class weighting may not match real costs (some false negatives/positives may matter more than others).
Good use-cases#
Imbalanced classification where you want good recall for every class.
Settings where the minority class is important and accuracy would be misleading.
If you care about ranking probabilities rather than hard labels, consider threshold-free metrics such as AUROC or Average Precision (PR AUC).
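For example, a quick sketch on the validation probabilities of the class-weighted model (assuming y_val and p_wt from the earlier cells):

from sklearn.metrics import roc_auc_score, average_precision_score

# Threshold-free metrics: evaluate the probability ranking directly
print("ROC AUC:           ", roc_auc_score(y_val, p_wt))
print("Average precision: ", average_precision_score(y_val, p_wt))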
9) Exercises#
Compute balanced accuracy by hand for a binary confusion matrix.
Test balanced_accuracy_score_np(..., sample_weight=...):
give higher weight to a subset of samples
verify it matches sklearn.metrics.balanced_accuracy_score(..., sample_weight=...).
On the synthetic dataset above:
compare accuracy vs balanced accuracy as you vary the threshold
find a threshold that maximizes balanced accuracy and report the per-class recalls.
Extend the notebook to multiclass:
generate 3 classes with imbalance
compute per-class recalls and balanced accuracy
visualize the confusion matrix.
References#
scikit-learn docs: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.balanced_accuracy_score.html
scikit-learn user guide (model evaluation): https://scikit-learn.org/stable/modules/model_evaluation.html